# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

from __future__ import annotations

import dataclasses

from symforce import codegen
from symforce import typing as T
from symforce.codegen import codegen_config
from symforce.values import Values


def _to_hashable(obj: T.Any) -> T.Hashable:
    """
    Recursively convert ``obj`` into a hashable value, for use in ``__hash__``.

    Dataclass instances become tuples of ``(field_name, value)`` pairs (similar to
    ``dataclasses.asdict`` with a ``tuple`` dict_factory, but without deep-copying). Only
    fields with ``compare=True`` are included, matching what the dataclass-generated
    ``__eq__`` looks at. Lists and tuples become tuples, dicts become frozensets of
    ``(key, value)`` pairs, and sets become frozensets. Any other object is returned as-is
    and must therefore already be hashable.
    """
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return tuple(
            (field.name, _to_hashable(getattr(obj, field.name)))
            for field in dataclasses.fields(obj)
            if field.compare
        )
    if isinstance(obj, (list, tuple)):
        return tuple(_to_hashable(item) for item in obj)
    if isinstance(obj, dict):
        return frozenset((_to_hashable(key), _to_hashable(value)) for key, value in obj.items())
    if isinstance(obj, (set, frozenset)):
        return frozenset(_to_hashable(item) for item in obj)
    return obj


@dataclasses.dataclass
class SimilarityIndex:
    """
    Contains all the information needed to assess if two Codegen objects
    would generate the same function, modulo function name and docstring.

    WARNING: SimilarityIndex is hashable despite being mutable. This means
    you should be careful when storing it as a key of a dict, as ordinary
    keys are immutable.
    """

    config: codegen_config.CodegenConfig
    inputs: Values
    outputs: Values
    return_key: T.Optional[str]
    # NOTE(brad): Only the keys are needed because Codegen will generate sparse_mat_data
    # to be the same if both the keys and the outputs are the same between two objects.
    sorted_sparse_matrices: T.Tuple[str, ...] = dataclasses.field(init=False)
    # Init-only argument: names of the sparse matrix outputs, in any order. It is not stored
    # as a field; __post_init__ sorts it into sorted_sparse_matrices.
    sparse_matrices: dataclasses.InitVar[T.Iterable[str]]

    def __post_init__(self, sparse_matrices: T.Iterable[str]) -> None:
        # Sort so that indices built from the same keys in a different order compare
        # (and hash) equal.
        self.sorted_sparse_matrices = tuple(sorted(sparse_matrices))

    @classmethod
    def from_codegen(cls, co: codegen.Codegen) -> SimilarityIndex:
        """
        Returns the SimilarityIndex of a Codegen object.

        If co1 and co2 are two Codegen objects, then
        from_codegen(co1) == from_codegen(co2) if and only if the function
        generated by co1.generate_function() is the same as that of co2.generate_function()
        (up to differences in function name and doc-strings).

        Args:
            co: The Codegen object whose inputs, outputs, config, return_key and
                sparse_mat_data keys are captured.
        """
        return cls(
            inputs=co.inputs,
            outputs=co.outputs,
            config=co.config,
            return_key=co.return_key,
            sparse_matrices=co.sparse_mat_data.keys(),
        )

    def __hash__(self) -> int:
        """
        WARNING: SimilarityIndex is mutable, and you must be mindful of the fact that
        keys of dicts are supposed to be immutable.

        If seeking to use a SimilarityIndex in a dict as a key, encapsulate this to
        make sure others aren't able to modify the object after it has been hashed.
        """
        return hash(
            (
                tuple(self.inputs.to_storage()),
                tuple(self.outputs.to_storage()),
                self.return_key,
                self.sorted_sparse_matrices,
                # Convert the config to (field name, value) pairs recursively. Unlike astuple
                # this keeps field names, and unlike asdict it also converts list/dict/set
                # field values into hashable tuples/frozensets instead of raising TypeError.
                _to_hashable(self.config),
            )
        )
